# Copyright 2018 MLBenchmark Group. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
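"""Verification script to check model tag sets.

Reports tag names that are defined in more than one model tag module, and tags
in `tags.ALL_USED_TAGS` that contain characters other than ASCII letters,
digits and underscores.
"""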

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from collections import defaultdict
import re

from mlperf_compliance import _gnmt_tags
from mlperf_compliance import _ncf_tags
from mlperf_compliance import _resnet_tags
from mlperf_compliance import _ssd_tags
from mlperf_compliance import _transformer_tags
from mlperf_compliance  import tags


# Per-model tag modules whose public string constants are checked for name
# collisions by check_collisions().
_MODEL_TAG_MODULES = [
    _gnmt_tags,
    _ncf_tags,
    _resnet_tags,
    _ssd_tags,
    _transformer_tags,
]

# A well-formed tag is a non-empty run of ASCII letters, digits and
# underscores. `\Z` (not `$`) anchors at the true end of the string: `$` also
# matches just before a trailing newline, which would accept "tag\n".
TAG_PATTERN = re.compile(r"^[A-Za-z0-9_]+\Z")


def extract_tags(module):
  """Return the names of `module`'s public string-valued attributes.

  Names beginning with an underscore are skipped, as are attributes whose
  value is not a `str`. The result keeps the (sorted) order of `dir()`.
  """
  return [
      name for name in dir(module)
      if not name.startswith("_") and isinstance(getattr(module, name), str)
  ]


def check_collisions():
  """Report tag names defined by more than one model tag module.

  Every name returned by `extract_tags` for each module in
  `_MODEL_TAG_MODULES` is recorded against the module(s) defining it. Each
  name defined in more than one module is printed, in sorted order.

  Returns:
    A dict mapping each duplicated name to the list of module names that
    define it, in `_MODEL_TAG_MODULES` order (empty when there are no
    collisions).
  """
  defining_modules = defaultdict(list)
  for module in _MODEL_TAG_MODULES:
    for tag in extract_tags(module):
      defining_modules[tag].append(module.__name__)

  duplicate_defs = {tag: names for tag, names in defining_modules.items()
                    if len(names) > 1}
  # Sort so the report is stable across runs and Python versions; dict
  # iteration order is not guaranteed to be insertion order before 3.7.
  for tag in sorted(duplicate_defs):
    print("Variable {} defined multiple times".format(tag))
  return duplicate_defs


def check_format():
  """Report tags in `tags.ALL_USED_TAGS` that do not fully match TAG_PATTERN.

  Each malformed tag is printed, in sorted order.

  Returns:
    A sorted list of the malformed tags (empty when every tag is valid).
  """
  malformed = []
  for tag in sorted(tags.ALL_USED_TAGS):
    match = TAG_PATTERN.match(tag)
    # A `$`-anchored pattern also matches just before a trailing newline, so a
    # bare match() would accept e.g. "run_start\n". Requiring the match to
    # reach the end of the string rejects such tags without needing
    # re.fullmatch, which does not exist on Python 2.
    if match is None or match.end() != len(tag):
      print("Malformed tag: {}".format(tag))
      malformed.append(tag)
  return malformed


if __name__ == "__main__":
  # Run each check in turn; every check prints the problems it finds.
  for _check in (check_collisions, check_format):
    _check()

